import os
import math
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random
from scipy import ndarray
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score
from collections import Counter
from imblearn.datasets import fetch_datasets
import keras
from keras.layers import Dense, Dropout, Input
from keras.models import Model,Sequential
from tqdm import tqdm
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import Adam
from keras.optimizers import RMSprop
from keras import losses
from keras import backend as K
import tensorflow as tf
import warnings
warnings.filterwarnings("ignore")
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import GradientBoostingClassifier
## Load the imbalanced 'yeast_me2' benchmark through imblearn's fetch_datasets
data = fetch_datasets()['yeast_me2']
labels = data.target      ## labels of the data; recorded labels.shape -> (1484,)
features = data.data      ## features of the data; recorded features.shape -> (1484, 8)
n_feat = features.shape[1]  ## number of features, this is later used in the model architecture also


def _split_two_thirds(points):
    ## Ordered (unshuffled) split: the first ceil(2/3 * n) rows form the
    ## training part, the remaining rows form the test part.
    cut = math.ceil(len(points) * 2 / 3)
    return points[:cut], points[cut:]


## minority class (label 1)
label_1 = list(np.where(labels == 1)[0])   ## minority indexes; recorded len(label_1) -> 51
features_1 = features[label_1]             ## minority points
## dividing minority points into training and test sets
features_1_trn, features_1_tst = _split_two_thirds(features_1)
## recorded sizes: len(features_1_trn) -> 34, len(features_1_tst) -> 17

## majority class (label -1)
label_0 = list(np.where(labels == -1)[0])  ## majority indexes; recorded len(label_0) -> 1433
features_0 = features[label_0]             ## majority points
## division of majority into train and test sets
features_0_trn, features_0_tst = _split_two_thirds(features_0)
## recorded sizes: len(features_0_trn) -> 956, len(features_0_tst) -> 477

## minority rows come first, majority rows second, in both sets
training_data = np.concatenate((features_1_trn, features_0_trn))
test_data = np.concatenate((features_1_tst, features_0_tst))

## binary labels in the same order: 1 -> minority, 0 -> majority
training_labels = np.concatenate((np.ones(len(features_1_trn)), np.zeros(len(features_0_trn))))
test_labels = np.concatenate((np.ones(len(features_1_tst)), np.zeros(len(features_0_tst))))
## recorded training_labels.shape -> (990,)
So far we have obtained the data and divided it into training and test sets. We created separate variables for the majority and minority classes, and for their labels, in both sets.
Some Notes:
## gen -> number of convex combinations generated from each neighbourhood
## neb -> size of the oversampling neighbourhood (kept equal to gen here)
gen = 5
neb = gen
We now generate the 'labels' array, an array with 2*gen one-hot rows. This array will later be used as the batch labels to train the discriminator (see Figure).
## One-hot targets for one discriminator batch of 2*gen rows:
## the first gen rows are [1, 0], the last gen rows are [0, 1].
## NOTE: this rebinds the name 'labels', which held the dataset labels until here.
labels = np.repeat(np.array([[1, 0], [0, 1]]), gen, axis=0)
labels = tf.convert_to_tensor(labels)
def BMB(data_min, data_maj, neb, gen):
    ## Generate a borderline majority batch
    ## data_min -> minority class data (2-D array, one sample per row)
    ## data_maj -> majority class data (2-D array, one sample per row)
    ## neb -> oversampling neighbourhood (number of majority neighbours looked up per minority sample)
    ## gen -> number of majority samples in the returned batch
    ## returns -> tensor with gen rows drawn (with replacement) from the borderline majority samples,
    ##            i.e. the majority samples that are among the neb nearest majority neighbours
    ##            of at least one minority sample
    from sklearn.neighbors import NearestNeighbors

    data_min = np.asarray(data_min)
    data_maj = np.asarray(data_maj)

    if len(data_min) == 0:
        ## no minority sample to define a border: every majority sample is a candidate
        bmbi = np.arange(len(data_maj))
    else:
        ## n_neighbors must be passed by keyword in current scikit-learn releases
        neigh = NearestNeighbors(n_neighbors=neb)
        neigh.fit(data_maj)
        ## a single query for all minority samples; result has one row of neb indices per sample
        neighbour_idx = neigh.kneighbors(data_min, n_neighbors=neb, return_distance=False)
        bmbi = np.unique(neighbour_idx)  ## de-duplicated borderline majority indexes

    ## draw the batch from the borderline subset of data_maj (the previous version
    ## computed this subset but then sampled from the whole global majority set)
    chosen = np.random.choice(bmbi, size=gen)
    bmb = tf.convert_to_tensor(data_maj[chosen])
    return bmb
An example of the function BMB generating a borderline majority neighbourhood of size 5. The majority class is first analyzed to find the borderline samples. This is the subset of majority class samples that are in the neb-nearest neighbour set of at least one minority class sample. We call this function every time we want to input such a batch as a part of the discriminator input. The other part of the discriminator input is to be generated by the generator. (See Figure)
## example call: a majority batch of gen rows for a neighbourhood size of 5
BMB(data_min=features_1_trn, data_maj=features_0_trn, neb=5, gen=gen)
<tf.Tensor: shape=(5, 8), dtype=float64, numpy=
array([[0.45, 0.51, 0.5 , 0.16, 0.5 , 0. , 0.52, 1. ],
[0.38, 0.49, 0.54, 1. , 0.5 , 0. , 0.5 , 0.27],
[0.47, 0.49, 0.63, 0.17, 0.5 , 0. , 0.31, 0.26],
[0.43, 0.39, 0.59, 0.13, 0.5 , 0. , 0.41, 0.43],
[0.45, 0.56, 0.51, 0.64, 0.5 , 0. , 0.54, 0.22]])>
def NMB(data_min, neb):
    ## Generate a minority neighbourhood batch around a randomly chosen minority sample
    ## data_min -> minority class data (2-D array, one sample per row)
    ## neb -> oversampling neighbourhood (number of rows in the returned batch)
    ## returns -> tensor holding the neb nearest minority neighbours of the chosen
    ##            sample (the sample itself included), in random order
    from sklearn.neighbors import NearestNeighbors

    data_min = np.asarray(data_min)
    ## n_neighbors must be passed by keyword in current scikit-learn releases
    neigh = NearestNeighbors(n_neighbors=neb)
    neigh.fit(data_min)
    ind = np.random.randint(len(data_min))
    ## kneighbors returns shape (1, neb) for a single query; take the only row
    nmbi = neigh.kneighbors([data_min[ind]], n_neighbors=neb, return_distance=False)[0]
    ## permute the neighbourhood itself (shuffling the (1, neb) array only
    ## reorders its single row and therefore changes nothing)
    nmbi = np.random.permutation(nmbi)
    ## index the data that was passed in, not a module-level array
    nmb = tf.convert_to_tensor(data_min[nmbi])
    return nmb
An example of the function NMB generating a minority neighbourhood of size 5. We call this function every time we want to input such a batch to the generator network.
## example call: one random minority neighbourhood of neb rows
NMB(data_min=features_1_trn, neb=neb)
<tf.Tensor: shape=(5, 8), dtype=float64, numpy=
array([[0.83, 0.57, 0.42, 0.57, 0.5 , 0. , 0.5 , 0.32],
[0.94, 0.6 , 0.33, 0.49, 0.5 , 0. , 0.54, 0.22],
[0.66, 0.66, 0.32, 0.54, 0.5 , 0. , 0.51, 0.22],
[0.74, 0.56, 0.35, 0.36, 0.5 , 0. , 0.51, 0.25],
[0.91, 0.69, 0.57, 0.45, 0.5 , 0. , 0.56, 0.22]])>
## show the image file 'CoSPOV.jpg' (relative path) -- presumably the figure the notes refer to as "See Figure"
from IPython.display import Image
Image(filename='CoSPOV.jpg')
def conv_sample_gen():
    ## Build and compile the generator network. It maps a minority neighbourhood
    ## batch of neb rows to gen synthetic rows, each a non-negative weighted
    ## combination of the input rows with weights that sum to (approximately) one.
    ## Reads the module-level values n_feat, neb and gen.
    min_neb_batch = keras.layers.Input(shape=(n_feat,))  ## minority neighbourhood batch, one sample per row

    ## add a leading axis of size 1 so that the 1-D convolution can run over the neb rows
    batch_3d = tf.reshape(min_neb_batch, (1, neb, n_feat))
    conv_out = keras.layers.Conv1D(n_feat, 3, activation='relu')(batch_3d)  ## n_feat filters, kernel size 3
    flat = keras.layers.Flatten()(conv_out)
    ## one non-negative raw weight per (input row, synthetic sample) pair
    raw = keras.layers.Dense(neb * gen, activation='relu')(flat)
    weights = keras.layers.Reshape((neb, gen))(raw)  ## shape (1, neb, gen)

    ## sum over the neb axis: one total per synthetic sample (per column), shape (1, gen)
    col_sum = K.sum(weights, axis=1)
    ## small constant so the reciprocal is defined even when a whole column is zero
    col_sum_safe = tf.keras.layers.Lambda(lambda t: t + .0001)(col_sum)
    col_sum_inv = tf.math.reciprocal(col_sum_safe)
    ## broadcasted product: every column is divided by its own total
    normalised = keras.layers.Multiply()([col_sum_inv, weights])

    ## drop the leading axis and transpose: each ROW is now one set of coefficients, shape (gen, neb)
    aff = tf.transpose(normalised[0])
    ## (gen, neb) x (neb, n_feat): the synthetic samples
    synth = tf.matmul(aff, min_neb_batch)

    model = Model(inputs=min_neb_batch, outputs=synth)
    model.compile(
        loss='mean_squared_logarithmic_error',
        optimizer=keras.optimizers.Adam(learning_rate=0.001),
    )
    return model
## instantiate the generator network and print its architecture
conv_sample_generator=conv_sample_gen()
conv_sample_generator.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 8)] 0
__________________________________________________________________________________________________
tf.reshape (TFOpLambda) (1, 5, 8) 0 input_1[0][0]
__________________________________________________________________________________________________
conv1d (Conv1D) (1, 3, 8) 200 tf.reshape[0][0]
__________________________________________________________________________________________________
flatten (Flatten) (1, 24) 0 conv1d[0][0]
__________________________________________________________________________________________________
dense (Dense) (1, 25) 625 flatten[0][0]
__________________________________________________________________________________________________
reshape (Reshape) (1, 5, 5) 0 dense[0][0]
__________________________________________________________________________________________________
tf.math.reduce_sum (TFOpLambda) (1, 5) 0 reshape[0][0]
__________________________________________________________________________________________________
lambda (Lambda) (1, 5) 0 tf.math.reduce_sum[0][0]
__________________________________________________________________________________________________
tf.math.reciprocal (TFOpLambda) (1, 5) 0 lambda[0][0]
__________________________________________________________________________________________________
multiply (Multiply) (1, 5, 5) 0 tf.math.reciprocal[0][0]
reshape[0][0]
__________________________________________________________________________________________________
tf.__operators__.getitem (Slici (5, 5) 0 multiply[0][0]
__________________________________________________________________________________________________
tf.compat.v1.transpose (TFOpLam (5, 5) 0 tf.__operators__.getitem[0][0]
__________________________________________________________________________________________________
tf.linalg.matmul (TFOpLambda) (5, 8) 0 tf.compat.v1.transpose[0][0]
input_1[0][0]
==================================================================================================
Total params: 825
Trainable params: 825
Non-trainable params: 0
__________________________________________________________________________________________________
def maj_min_disc():
    ## Build and compile the discriminator network.
    ## It is trained in two phases:
    ## first phase: while training the GAN it learns to tell synthetic minority samples
    ##              (from the convex minority data space) apart from borderline majority samples
    ## second phase: once the generator is trained, synthetic samples balance the dataset
    ##               and the discriminator is retrained on the balanced dataset
    ## Reads the module-level value n_feat.
    samples = keras.layers.Input(shape=(n_feat,))  ## one sample per row

    ## two hidden dense layers: 250 units, then 125 units
    hidden = samples
    for width in (250, 125):
        hidden = keras.layers.Dense(width, activation='relu')(hidden)

    ## two output nodes; targets have to be one-hot coded
    output = keras.layers.Dense(2, activation='sigmoid')(hidden)

    model = Model(inputs=samples, outputs=output)
    model.compile(
        loss='binary_crossentropy',
        optimizer=keras.optimizers.Adam(learning_rate=0.0001),
    )
    return model
## instantiate the discriminator network and print its architecture
maj_min_discriminator=maj_min_disc()
maj_min_discriminator.summary()
Model: "model_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_2 (InputLayer) [(None, 8)] 0 _________________________________________________________________ dense_1 (Dense) (None, 250) 2250 _________________________________________________________________ dense_2 (Dense) (None, 125) 31375 _________________________________________________________________ dense_3 (Dense) (None, 2) 252 ================================================================= Total params: 33,877 Trainable params: 33,877 Non-trainable params: 0 _________________________________________________________________
def convGAN(conv_coeff_generator, maj_min_discriminator):
    ## Join the generator and the discriminator into one trainable model
    ## conv_coeff_generator -> generator network instance
    ## maj_min_discriminator -> discriminator network instance
    ## returns -> compiled model; its input is a minority neighbourhood batch (first neb rows)
    ##            stacked on a majority batch (remaining rows), its output the discriminator decisions
    ## Reads the module-level values n_feat and neb.
    ## Side effect: leaves maj_min_discriminator.trainable set to False.

    ## Freeze the discriminator INSTANCE before compiling, so that training the combined
    ## model only updates the generator. (The earlier code set the attribute on the
    ## builder function maj_min_disc, which froze nothing.)
    maj_min_discriminator.trainable = False

    batch_data = keras.layers.Input(shape=(n_feat,))
    min_batch = tf.keras.layers.Lambda(lambda x: x[:neb])(batch_data)  ## extract minority batch
    maj_batch = tf.keras.layers.Lambda(lambda x: x[neb:])(batch_data)  ## extract majority batch

    ## use the generator that was passed in, not a module-level instance
    conv_samples = conv_coeff_generator(min_batch)
    ## synthetic samples first, majority samples second
    new_samples = tf.concat([conv_samples, maj_batch], axis=0)
    ## the discriminator only judges here; it is not trained through this model
    output = maj_min_discriminator(new_samples)

    model = Model(inputs=batch_data, outputs=output)
    opt = keras.optimizers.Adam(learning_rate=0.0001)
    model.compile(loss='mse', optimizer=opt)
    return model
## instantiate the combined generator + discriminator network and print its architecture
cg=convGAN(conv_sample_generator,maj_min_discriminator)
cg.summary()
Model: "model_2"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_3 (InputLayer) [(None, 8)] 0
__________________________________________________________________________________________________
lambda_1 (Lambda) (None, 8) 0 input_3[0][0]
__________________________________________________________________________________________________
model (Functional) (5, 8) 825 lambda_1[0][0]
__________________________________________________________________________________________________
lambda_2 (Lambda) (None, 8) 0 input_3[0][0]
__________________________________________________________________________________________________
tf.concat (TFOpLambda) (None, 8) 0 model[0][0]
lambda_2[0][0]
__________________________________________________________________________________________________
model_1 (Functional) (None, 2) 33877 tf.concat[0][0]
==================================================================================================
Total params: 34,702
Trainable params: 34,702
Non-trainable params: 0
__________________________________________________________________________________________________
## This is the main training process where the GAN learns to generate appropriate samples from the convex space.
## It is the first training phase for the discriminator and the only training phase for the generator.
runs = 300
loss_history = []  ## loss of the combined model for every run
for step in range(1, runs + 1):
    min_batch = NMB(features_1_trn, neb)  ## random minority neighbourhood batch
    maj_batch = BMB(features_1_trn, features_0_trn, neb, gen)  ## random proximal majority batch

    ## synthetic samples from the convex space of the minority neighbourhood batch
    conv_samples = conv_sample_generator.predict(min_batch)
    ## synthetic samples stacked on the majority batch: the discriminator's training input
    concat_sample = tf.concat([conv_samples, maj_batch], axis=0)

    maj_min_discriminator.trainable = True   ## switch on discriminator training
    maj_min_discriminator.fit(x=concat_sample, y=labels, verbose=0)
    maj_min_discriminator.trainable = False  ## switch off the discriminator training again

    ## cg runs the generator itself on the first neb rows of its input, so it must be given
    ## the raw minority neighbourhood, not the already generated samples
    gan_input = tf.concat([min_batch, maj_batch], axis=0)
    ## cg splits its input by row position: keep the row order (shuffle=False)
    ## and feed all rows as a single batch
    gan_loss_history = cg.fit(gan_input, y=labels, batch_size=neb + gen, shuffle=False, verbose=0)
    loss_history.append(gan_loss_history.history['loss'])  ## store the loss for the step

    if step % 10 == 0:
        print(str(step)+' batches trained')
10 batches trained 20 batches trained 30 batches trained 40 batches trained 50 batches trained 60 batches trained 70 batches trained 80 batches trained 90 batches trained 100 batches trained 110 batches trained 120 batches trained 130 batches trained 140 batches trained 150 batches trained 160 batches trained 170 batches trained 180 batches trained 190 batches trained 200 batches trained 210 batches trained 220 batches trained 230 batches trained 240 batches trained 250 batches trained 260 batches trained 270 batches trained 280 batches trained 290 batches trained 300 batches trained
## plotting the loss history
## we see the loss has a decreasing trend but fluctuates with increasing runs
## this is due to the large variation in the inputs induced by the permutations of the input minority and majority batches
## and also the diverse convex combinations generated by the generator network on these permuted batches
## the permutations can be important while learning from smaller amounts of data
run_range=range(1,runs+1)
plt.rcParams["figure.figsize"] = (16,10)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.xlabel('runs',fontsize=25)
plt.ylabel('loss', fontsize=25)
plt.plot(run_range, loss_history)
[<matplotlib.lines.Line2D at 0x205bfdf2a90>]
## Plot a 2-D PCA of the minority batch, the synthetic samples and the majority batch of the last training run.
## note that the synthetic samples are generated far from the majority class yet closer to minority
min_points = np.array(min_batch)
maj_points = np.array(maj_batch)
samples_for_plot = np.concatenate([min_points, conv_samples, maj_points])
## group id per row (0 minority, 1 synthetic minority, 2 majority), sized from the
## actual arrays so the labels stay aligned even when neb and gen differ
lab = np.repeat([0.0, 1.0, 2.0], [len(min_points), len(conv_samples), len(maj_points)])
## do PCA
pca = PCA(n_components=2)
pca.fit(samples_for_plot)
data_pca = X = pca.transform(samples_for_plot)
## plot
plt.rcParams["figure.figsize"] = (12,12)
colors = ['r', 'b', 'g']  ## not used by the scatter below, which colours by cmap
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
## the axes are principal components (they were mislabelled 'runs' / 'loss')
plt.xlabel('PCA1', fontsize=25)
plt.ylabel('PCA2', fontsize=25)
classes = ['minority', 'synthetic minority', 'majority']
scatter = plt.scatter(data_pca[:, 0], data_pca[:, 1], c=lab, cmap='Set1')
plt.legend(handles=scatter.legend_elements()[0], labels=classes, fontsize=20)
plt.show()
## One-hot encoding of the entire training labels: minority (1) -> [1, 0], majority (0) -> [0, 1].
## Derived from training_labels itself, so it cannot drift from the row order of training_data.
training_labels_oh = np.stack([training_labels, 1 - training_labels], axis=1).astype(int)
training_labels_oh = tf.convert_to_tensor(training_labels_oh)
## tensor version of the training data
training_data = tf.convert_to_tensor(training_data)
## recorded training_data.shape -> TensorShape([990, 8])
## After the first training phase the discriminator can already be used as a classifier:
## it has learned to separate convex minority points from majority points.
test_tensor = tf.convert_to_tensor(test_data)
y_pred_2d = maj_min_discriminator.predict(test_tensor)
## discretisation of the labels: first output node >= 0.5 -> class 1 (minority), else class 0
minority_score = y_pred_2d[:, 0]
y_pred = np.digitize(minority_score, bins=[.5])
## the recorded run shows a model with good recall and low precision
first_phase_cm = confusion_matrix(test_labels, y_pred)
first_phase_f1 = f1_score(test_labels, y_pred)
print(first_phase_cm)
print(first_phase_f1)
[[422 55] [ 4 13]] 0.3058823529411765
def NMB_guided(data_min, neb, index):
    ## Generate the minority neighbourhood batch of one particular minority sample.
    ## Needed for data generation: synthetic samples are produced for each training neighbourhood.
    ## data_min -> minority class data (2-D array, one sample per row)
    ## neb -> oversampling neighbourhood (number of rows in the returned batch)
    ## index -> row index in data_min of the sample whose neighbourhood is wanted
    ## returns -> tensor holding the neb nearest minority neighbours of that sample
    ##            (the sample itself included), in random order
    from sklearn.neighbors import NearestNeighbors

    data_min = np.asarray(data_min)
    ## n_neighbors must be passed by keyword in current scikit-learn releases
    neigh = NearestNeighbors(n_neighbors=neb)
    neigh.fit(data_min)
    ## kneighbors returns shape (1, neb) for a single query; take the only row
    nmbi = neigh.kneighbors([data_min[index]], n_neighbors=neb, return_distance=False)[0]
    ## permute the neighbourhood itself, so repeated calls for the same index give
    ## differently ordered batches (shuffling the (1, neb) array changed nothing)
    nmbi = np.random.permutation(nmbi)
    ## index the data that was passed in, not a module-level array
    nmb = tf.convert_to_tensor(data_min[nmbi])
    return nmb
def generate_data_for_min_point(data_min, neb, index, synth_num):
    ## Generate synth_num synthetic points for one particular minority sample.
    ## data_min -> minority class data
    ## neb -> oversampling neighbourhood
    ## index -> index in data_min of the minority sample whose neighbourhood is used
    ## synth_num -> number of synthetic points to return
    ## returns -> array with the first synth_num generated points
    ## Uses the module-level generator instance conv_sample_generator.
    n_batches = int(synth_num / neb) + 1  ## generator calls needed to have at least synth_num points
    collected = []
    for _ in range(n_batches):
        neighbourhood = NMB_guided(data_min, neb, index)
        ## each predicted row is one synthetic point
        collected.extend(conv_sample_generator.predict(neighbourhood))
    ## discard the surplus from the last batch
    return np.array(collected[:synth_num])
## number of synthetic samples missing to balance the two classes
n_min_trn = len(features_1_trn)
n_maj_trn = len(features_0_trn)
n_missing = n_maj_trn - n_min_trn

## rough upper bound of the synthetic samples to be generated from each neighbourhood
synth_num = (n_missing // n_min_trn) + 1

## generate synth_num synthetic samples from each minority neighbourhood
synth_set = []
for i in range(n_min_trn):
    synth_i = generate_data_for_min_point(features_1_trn, neb, i, synth_num)
    synth_set.extend(synth_i)
## keep exactly as many synthetic samples as are needed to balance the two classes
synth_set = np.array(synth_set[:n_missing])

## data and labels of the oversampled set, used for visualisation and for the
## second-phase retraining of the discriminator
## row order: original minority, synthetic minority, majority
ovs_min_class = np.concatenate((features_1_trn, synth_set), axis=0)
ovs_training_dataset = np.concatenate((ovs_min_class, features_0_trn), axis=0)
n_ovs_min = len(ovs_min_class)

## plot groups: 0 original minority, 1 synthetic minority, 2 majority
ovs_pca_labels = np.repeat([0.0, 1.0, 2.0], [n_min_trn, len(synth_set), n_maj_trn])
## binary labels: 1 minority (original + synthetic), 0 majority
ovs_training_labels = np.repeat([1.0, 0.0], [n_ovs_min, n_maj_trn])
## one-hot labels: minority -> [1, 0], majority -> [0, 1]
ovs_training_labels_oh = np.repeat(np.array([[1, 0], [0, 1]]), [n_ovs_min, n_maj_trn], axis=0)
ovs_training_labels_oh = tf.convert_to_tensor(ovs_training_labels_oh)
## PCA visualization of the synthetic data together with the original training data
## observe how the minority samples from the convex space have optimal variance and avoid overlap with the majority
pca = PCA(n_components=2)
pca.fit(ovs_training_dataset)
data_pca=X = pca.transform(ovs_training_dataset)
## plot PCA (points are coloured by ovs_pca_labels through the 'Set1' colormap; 'colors' is not used by the scatter)
plt.rcParams["figure.figsize"] = (12,12)
colors=['r', 'b', 'g']
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.xlabel('PCA1',fontsize=25)
plt.ylabel('PCA2', fontsize=25)
classes = ['minority', 'synthetic minority', 'majority']
scatter=plt.scatter(data_pca[:,0], data_pca[:,1], c=ovs_pca_labels, cmap='Set1')
plt.legend(handles=scatter.legend_elements()[0], labels=classes, fontsize=20)
plt.show()
## second phase: retrain the discriminator on the balanced (oversampled) training set,
## 300 epochs with batches of 20 rows; the Keras History object is kept in history_second_learning
## NOTE(review): the training loop leaves maj_min_discriminator.trainable set to False -- confirm that this fit still updates the weights on the Keras version in use
history_second_learning=maj_min_discriminator.fit(x=ovs_training_dataset,y=ovs_training_labels_oh, batch_size=20, epochs=300)
Epoch 1/300 96/96 [==============================] - 1s 3ms/step - loss: 0.3674 Epoch 2/300 96/96 [==============================] - 0s 3ms/step - loss: 0.3450 Epoch 3/300 96/96 [==============================] - 0s 3ms/step - loss: 0.3295 Epoch 4/300 96/96 [==============================] - 0s 3ms/step - loss: 0.3199 Epoch 5/300 96/96 [==============================] - 0s 3ms/step - loss: 0.3124 Epoch 6/300 96/96 [==============================] - 0s 4ms/step - loss: 0.3074 Epoch 7/300 96/96 [==============================] - 0s 3ms/step - loss: 0.3030 Epoch 8/300 96/96 [==============================] - 0s 3ms/step - loss: 0.3000 Epoch 9/300 96/96 [==============================] - 0s 3ms/step - loss: 0.2969 Epoch 10/300 96/96 [==============================] - 0s 3ms/step - loss: 0.2941 Epoch 11/300 96/96 [==============================] - 0s 3ms/step - loss: 0.2912 Epoch 12/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2891 Epoch 13/300 96/96 [==============================] - 1s 5ms/step - loss: 0.2874 Epoch 14/300 96/96 [==============================] - 1s 5ms/step - loss: 0.2866 Epoch 15/300 96/96 [==============================] - 0s 5ms/step - loss: 0.2842 Epoch 16/300 96/96 [==============================] - 0s 5ms/step - loss: 0.2824 Epoch 17/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2810 Epoch 18/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2800 Epoch 19/300 96/96 [==============================] - 0s 5ms/step - loss: 0.2774 Epoch 20/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2772 Epoch 21/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2765 Epoch 22/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2760 Epoch 23/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2730 Epoch 24/300 96/96 [==============================] - ETA: 0s - loss: 0.273 - 0s 5ms/step - loss: 0.2718 Epoch 25/300 96/96 
[==============================] - 0s 5ms/step - loss: 0.2703 Epoch 26/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2701 Epoch 27/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2680 Epoch 28/300 96/96 [==============================] - 0s 5ms/step - loss: 0.2674 Epoch 29/300 96/96 [==============================] - 1s 6ms/step - loss: 0.2659 Epoch 30/300 96/96 [==============================] - 1s 6ms/step - loss: 0.2643 Epoch 31/300 96/96 [==============================] - 1s 6ms/step - loss: 0.2633 Epoch 32/300 96/96 [==============================] - 1s 6ms/step - loss: 0.2626 Epoch 33/300 96/96 [==============================] - 1s 6ms/step - loss: 0.2624 Epoch 34/300 96/96 [==============================] - 1s 6ms/step - loss: 0.2603 Epoch 35/300 96/96 [==============================] - 1s 6ms/step - loss: 0.2592 Epoch 36/300 96/96 [==============================] - 1s 6ms/step - loss: 0.2585 Epoch 37/300 96/96 [==============================] - 1s 6ms/step - loss: 0.2589 Epoch 38/300 96/96 [==============================] - 0s 5ms/step - loss: 0.2572 Epoch 39/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2579 Epoch 40/300 96/96 [==============================] - 0s 5ms/step - loss: 0.2545 Epoch 41/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2522 Epoch 42/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2522 Epoch 43/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2509 Epoch 44/300 96/96 [==============================] - 0s 5ms/step - loss: 0.2521 Epoch 45/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2489 Epoch 46/300 96/96 [==============================] - 0s 5ms/step - loss: 0.2482 Epoch 47/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2464 Epoch 48/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2451 Epoch 49/300 96/96 [==============================] - 0s 4ms/step - loss: 
0.2449 Epoch 50/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2440 Epoch 51/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2427 Epoch 52/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2424 Epoch 53/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2419A: 0s - loss: 0.2 Epoch 54/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2405 Epoch 55/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2396 Epoch 56/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2385 Epoch 57/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2386 Epoch 58/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2368 Epoch 59/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2350 Epoch 60/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2347 Epoch 61/300 96/96 [==============================] - 0s 3ms/step - loss: 0.2333 Epoch 62/300 96/96 [==============================] - 0s 3ms/step - loss: 0.2333 Epoch 63/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2318 Epoch 64/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2307 Epoch 65/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2320 Epoch 66/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2288 Epoch 67/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2285 Epoch 68/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2272 Epoch 69/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2274 Epoch 70/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2248 Epoch 71/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2247 Epoch 72/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2236 Epoch 73/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2235 Epoch 74/300 96/96 
[==============================] - 0s 4ms/step - loss: 0.2224 Epoch 75/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2219 Epoch 76/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2196 Epoch 77/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2200 Epoch 78/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2177 Epoch 79/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2197 Epoch 80/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2169 Epoch 81/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2169 Epoch 82/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2151 Epoch 83/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2163 Epoch 84/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2129 Epoch 85/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2141 Epoch 86/300 96/96 [==============================] - 0s 5ms/step - loss: 0.2117 Epoch 87/300 96/96 [==============================] - 0s 5ms/step - loss: 0.2108 Epoch 88/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2104 Epoch 89/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2098 Epoch 90/300 96/96 [==============================] - 0s 5ms/step - loss: 0.2087 Epoch 91/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2081 Epoch 92/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2069 Epoch 93/300 96/96 [==============================] - 0s 3ms/step - loss: 0.2066 Epoch 94/300 96/96 [==============================] - 0s 4ms/step - loss: 0.2061 Epoch 95/300 96/96 [==============================] - 0s 3ms/step - loss: 0.2052 Epoch 96/300 96/96 [==============================] - 0s 3ms/step - loss: 0.2051 Epoch 97/300 96/96 [==============================] - 0s 3ms/step - loss: 0.2032 Epoch 98/300 96/96 [==============================] - 0s 3ms/step - loss: 
0.2029 Epoch 99/300 96/96 [==============================] - 0s 3ms/step - loss: 0.2029 Epoch 100/300 96/96 [==============================] - 0s 3ms/step - loss: 0.2018 Epoch 101/300 96/96 [==============================] - 0s 3ms/step - loss: 0.2005 Epoch 102/300 96/96 [==============================] - 0s 3ms/step - loss: 0.2013 Epoch 103/300 96/96 [==============================] - 0s 3ms/step - loss: 0.2005 Epoch 104/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1995 Epoch 105/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1977 Epoch 106/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1980 Epoch 107/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1969 Epoch 108/300 96/96 [==============================] - 0s 5ms/step - loss: 0.1954 Epoch 109/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1970 Epoch 110/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1950 Epoch 111/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1945 Epoch 112/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1938 Epoch 113/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1923 Epoch 114/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1925 Epoch 115/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1925 Epoch 116/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1909 Epoch 117/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1915 Epoch 118/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1904 Epoch 119/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1889 Epoch 120/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1899 Epoch 121/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1891 Epoch 122/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1868 Epoch 123/300 96/96 
[==============================] - 0s 3ms/step - loss: 0.1869 Epoch 124/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1848 Epoch 125/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1869 Epoch 126/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1845 Epoch 127/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1850 Epoch 128/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1838 Epoch 129/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1834 Epoch 130/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1837 Epoch 131/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1821 Epoch 132/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1827 Epoch 133/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1819 Epoch 134/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1832 Epoch 135/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1815 Epoch 136/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1790 Epoch 137/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1786 Epoch 138/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1790 Epoch 139/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1786 Epoch 140/300 96/96 [==============================] - 1s 6ms/step - loss: 0.1772 Epoch 141/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1768 Epoch 142/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1772 Epoch 143/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1768 Epoch 144/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1768 Epoch 145/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1748 Epoch 146/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1744 Epoch 147/300 96/96 
[==============================] - 0s 3ms/step - loss: 0.1736 Epoch 148/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1738 Epoch 149/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1736 Epoch 150/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1715 Epoch 151/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1725 Epoch 152/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1703 Epoch 153/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1714 Epoch 154/300 96/96 [==============================] - 0s 5ms/step - loss: 0.1718 Epoch 155/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1713 Epoch 156/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1694 Epoch 157/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1681 Epoch 158/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1700 Epoch 159/300 96/96 [==============================] - 0s 5ms/step - loss: 0.1694 Epoch 160/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1673 Epoch 161/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1676 Epoch 162/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1679 Epoch 163/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1660 Epoch 164/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1684 Epoch 165/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1670 Epoch 166/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1653 Epoch 167/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1646 Epoch 168/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1645 Epoch 169/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1638 Epoch 170/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1624 Epoch 171/300 96/96 
[==============================] - 0s 3ms/step - loss: 0.1646 Epoch 172/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1621 Epoch 173/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1626 Epoch 174/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1617 Epoch 175/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1608 Epoch 176/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1601 Epoch 177/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1639 Epoch 178/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1592 Epoch 179/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1584 Epoch 180/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1586 Epoch 181/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1588 Epoch 182/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1579 Epoch 183/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1578 Epoch 184/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1584 Epoch 185/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1565 Epoch 186/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1563 Epoch 187/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1575 Epoch 188/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1553 Epoch 189/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1544 Epoch 190/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1543 Epoch 191/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1546 Epoch 192/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1537 Epoch 193/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1538 Epoch 194/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1529 Epoch 195/300 96/96 
[==============================] - 0s 3ms/step - loss: 0.1550 Epoch 196/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1524 Epoch 197/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1518 Epoch 198/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1514 Epoch 199/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1510 Epoch 200/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1509 Epoch 201/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1516 Epoch 202/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1514 Epoch 203/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1500 Epoch 204/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1488 Epoch 205/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1477 Epoch 206/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1497 Epoch 207/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1484 Epoch 208/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1488 Epoch 209/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1478 Epoch 210/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1474 Epoch 211/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1458 Epoch 212/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1453 Epoch 213/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1460 Epoch 214/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1475 Epoch 215/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1454 Epoch 216/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1437 Epoch 217/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1444 Epoch 218/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1432 Epoch 219/300 96/96 
[==============================] - 0s 4ms/step - loss: 0.1432 Epoch 220/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1427 Epoch 221/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1431 Epoch 222/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1419 Epoch 223/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1424 Epoch 224/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1419 Epoch 225/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1423 Epoch 226/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1414 Epoch 227/300 96/96 [==============================] - 0s 5ms/step - loss: 0.1411 Epoch 228/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1404 Epoch 229/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1393 Epoch 230/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1397 Epoch 231/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1378 Epoch 232/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1388 Epoch 233/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1379 Epoch 234/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1377 Epoch 235/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1376 Epoch 236/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1364 Epoch 237/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1353 Epoch 238/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1371 Epoch 239/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1370 Epoch 240/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1371 Epoch 241/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1344 Epoch 242/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1354 Epoch 243/300 96/96 
[==============================] - 0s 3ms/step - loss: 0.1345 Epoch 244/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1344 Epoch 245/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1341 Epoch 246/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1337 Epoch 247/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1331 Epoch 248/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1331A: 0s - loss: 0. Epoch 249/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1331 Epoch 250/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1351 Epoch 251/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1313 Epoch 252/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1324 Epoch 253/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1299 Epoch 254/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1296 Epoch 255/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1297 Epoch 256/300 96/96 [==============================] - 0s 5ms/step - loss: 0.1306 Epoch 257/300 96/96 [==============================] - 0s 5ms/step - loss: 0.1312 Epoch 258/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1306 Epoch 259/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1296 Epoch 260/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1318 Epoch 261/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1302 Epoch 262/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1312A: 0s - loss: 0. 
Epoch 263/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1267 Epoch 264/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1288 Epoch 265/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1285 Epoch 266/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1274 Epoch 267/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1285 Epoch 268/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1264 Epoch 269/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1252 Epoch 270/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1268 Epoch 271/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1265 Epoch 272/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1276 Epoch 273/300 96/96 [==============================] - 0s 3ms/step - loss: 0.1250 Epoch 274/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1253 Epoch 275/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1248 Epoch 276/300 96/96 [==============================] - 0s 5ms/step - loss: 0.1251 Epoch 277/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1252 Epoch 278/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1246 Epoch 279/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1235 Epoch 280/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1267 Epoch 281/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1250 Epoch 282/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1231 Epoch 283/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1229 Epoch 284/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1231 Epoch 285/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1222 Epoch 286/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1225 Epoch 287/300 96/96 
[==============================] - 0s 5ms/step - loss: 0.1236 Epoch 288/300 96/96 [==============================] - 0s 5ms/step - loss: 0.1203 Epoch 289/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1217 Epoch 290/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1197 Epoch 291/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1214 Epoch 292/300 96/96 [==============================] - 0s 5ms/step - loss: 0.1219 Epoch 293/300 96/96 [==============================] - 1s 5ms/step - loss: 0.1198 Epoch 294/300 96/96 [==============================] - 0s 5ms/step - loss: 0.1211 Epoch 295/300 96/96 [==============================] - 0s 5ms/step - loss: 0.1203 Epoch 296/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1202 Epoch 297/300 96/96 [==============================] - 0s 4ms/step - loss: 0.1197 Epoch 298/300 96/96 [==============================] - 0s 5ms/step - loss: 0.1207 Epoch 299/300 96/96 [==============================] - 0s 5ms/step - loss: 0.1193 Epoch 300/300 96/96 [==============================] - 0s 5ms/step - loss: 0.1204
## loss of the second phase learning decreases smoothly
## this is because now the data is fixed and diverse convex combinations are no longer fed into the discriminator at every training step
## plot the second-phase training loss, one point per recorded epoch
## the x-axis length is taken from the recorded loss history instead of a hard-coded 300,
## so the plot stays valid if the number of training epochs is changed
second_phase_loss=history_second_learning.history['loss']
run_range=range(1,len(second_phase_loss)+1)
plt.rcParams["figure.figsize"] = (16,10)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.xlabel('runs',fontsize=25)
plt.ylabel('loss', fontsize=25)
plt.plot(run_range, second_phase_loss)
[<matplotlib.lines.Line2D at 0x205bff3eca0>]
## finally after second phase training the discriminator classifier has a more balanced performance
## meaning better F1-Score
## the recall decreases but the precision improves
## score the test set with the discriminator
y_pred_2d = maj_min_discriminator.predict(tf.convert_to_tensor(test_data))
## threshold the first output column with a single bin edge at 0.5:
## scores below 0.5 become class 0, scores of 0.5 or more become class 1
## (presumably class 1 is the minority class -- confirm against how test_labels is built)
y_pred = np.digitize(y_pred_2d[:, 0], bins=[0.5])
## report the confusion matrix first, then the F1-Score
print(confusion_matrix(y_true=test_labels, y_pred=y_pred))
print(f1_score(y_true=test_labels, y_pred=y_pred))
[[448 29] [ 7 10]] 0.35714285714285715
What needs to be done next: